{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multi-GPU Training Example\n",
    "\n",
    "Train a convolutional neural network on multiple GPUs with TensorFlow 2.0+.\n",
    "\n",
    "- Author: Aymeric Damien\n",
    "- Project: https://github.com/aymericdamien/TensorFlow-Examples/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training with multiple GPU cards\n",
    "\n",
    "In this example, we use data parallelism to split the training across multiple GPUs. Each GPU holds a full replica of the neural network model, and the weights (i.e. variables) are updated synchronously: an update is applied only after every GPU has finished processing its batch of data.\n",
    "\n",
    "First, each GPU processes a distinct batch of data and computes the corresponding gradients. Then, all gradients are accumulated on the CPU and averaged. Finally, the model weights are updated with the averaged gradients, and the new weights are sent back to each GPU to repeat the training process.\n",
    "\n",
    "<img src=\"https://www.tensorflow.org/images/Parallelism.png\" alt=\"Parallelism\" style=\"width: 400px;\"/>\n",
    "\n",
    "## CIFAR10 Dataset Overview\n",
    "\n",
    "The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.\n",
    "\n",
    "![CIFAR10 Dataset](https://storage.googleapis.com/kaggle-competitions/kaggle/3649/media/cifar-10.png)\n",
    "\n",
    "More info: https://www.cs.toronto.edu/~kriz/cifar.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import, division, print_function\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import Model, layers\n",
    "import time\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CIFAR-10 dataset parameters.\n",
    "num_classes = 10 # total classes (10 object categories).\n",
    "num_gpus = 4\n",
    "\n",
    "# Training parameters.\n",
    "learning_rate = 0.001\n",
    "training_steps = 1000\n",
    "# Global batch size; it is split equally between GPUs (1024 examples per GPU).\n",
    "# Note: Reduce batch size if you encounter OOM Errors.\n",
    "batch_size = 1024 * num_gpus\n",
    "display_step = 20\n",
    "\n",
    "# Network parameters.\n",
    "conv1_filters = 64 # number of filters for 1st conv layer.\n",
    "conv2_filters = 128 # number of filters for 2nd conv layer.\n",
    "conv3_filters = 256 # number of filters for 3rd conv layer.\n",
    "fc1_units = 1024 # number of neurons for 1st fully-connected layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare CIFAR-10 data.\n",
    "from tensorflow.keras.datasets import cifar10\n",
    "(x_train, y_train), (x_test, y_test) = cifar10.load_data()\n",
    "# Convert to float32.\n",
    "x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)\n",
    "# Normalize image values from [0, 255] to [0, 1].\n",
    "x_train, x_test = x_train / 255., x_test / 255.\n",
    "# Flatten labels from shape (N, 1) to (N,).\n",
    "y_train, y_test = np.reshape(y_train, (-1)), np.reshape(y_test, (-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use tf.data API to shuffle and batch data.\n",
    "train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
    "train_data = train_data.repeat().shuffle(batch_size * 10).batch(batch_size).prefetch(num_gpus)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConvNet(Model):\n",
    "    # Set layers.\n",
    "    def __init__(self):\n",
    "        super(ConvNet, self).__init__()\n",
    "\n",
    "        # Convolution layers with 64 filters and a kernel size of 3.\n",
    "        self.conv1_1 = layers.Conv2D(conv1_filters, kernel_size=3, padding='SAME', activation=tf.nn.relu)\n",
    "        self.conv1_2 = layers.Conv2D(conv1_filters, kernel_size=3, padding='SAME', activation=tf.nn.relu)\n",
    "        # Max Pooling (down-sampling) with kernel size of 2 and strides of 2.\n",
    "        self.maxpool1 = layers.MaxPool2D(2, strides=2)\n",
    "\n",
    "        # Convolution layers with 128 filters and a kernel size of 3.\n",
    "        self.conv2_1 = layers.Conv2D(conv2_filters, kernel_size=3, padding='SAME', activation=tf.nn.relu)\n",
    "        self.conv2_2 = layers.Conv2D(conv2_filters, kernel_size=3, padding='SAME', activation=tf.nn.relu)\n",
    "        self.conv2_3 = layers.Conv2D(conv2_filters, kernel_size=3, padding='SAME', activation=tf.nn.relu)\n",
    "        # Max Pooling (down-sampling) with kernel size of 2 and strides of 2.\n",
    "        self.maxpool2 = layers.MaxPool2D(2, strides=2)\n",
    "\n",
    "        # Convolution layers with 256 filters and a kernel size of 3.\n",
    "        self.conv3_1 = layers.Conv2D(conv3_filters, kernel_size=3, padding='SAME', activation=tf.nn.relu)\n",
    "        self.conv3_2 = layers.Conv2D(conv3_filters, kernel_size=3, padding='SAME', activation=tf.nn.relu)\n",
    "        self.conv3_3 = layers.Conv2D(conv3_filters, kernel_size=3, padding='SAME', activation=tf.nn.relu)\n",
    "\n",
    "        # Flatten the data to a 1-D vector for the fully connected layer.\n",
    "        self.flatten = layers.Flatten()\n",
    "\n",
    "        # Fully connected layer, sized by the fc1_units parameter.\n",
    "        self.fc1 = layers.Dense(fc1_units, activation=tf.nn.relu)\n",
    "        # Apply Dropout (if is_training is False, dropout is not applied).\n",
    "        self.dropout = layers.Dropout(rate=0.5)\n",
    "\n",
    "        # Output layer, class prediction.\n",
    "        self.out = layers.Dense(num_classes)\n",
    "\n",
    "    # Set forward pass.\n",
    "    @tf.function\n",
    "    def call(self, x, is_training=False):\n",
    "        x = self.conv1_1(x)\n",
    "        x = self.conv1_2(x)\n",
    "        x = self.maxpool1(x)\n",
    "        x = self.conv2_1(x)\n",
    "        x = self.conv2_2(x)\n",
    "        x = self.conv2_3(x)\n",
    "        x = self.maxpool2(x)\n",
    "        x = self.conv3_1(x)\n",
    "        x = self.conv3_2(x)\n",
    "        x = self.conv3_3(x)\n",
    "        x = self.flatten(x)\n",
    "        x = self.fc1(x)\n",
    "        x = self.dropout(x, training=is_training)\n",
    "        x = self.out(x)\n",
    "        if not is_training:\n",
    "            # tf cross entropy expects logits without softmax, so only\n",
    "            # apply softmax when not training.\n",
    "            x = tf.nn.softmax(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-Entropy Loss.\n",
    "# Note that this will apply 'softmax' to the logits.\n",
    "@tf.function\n",
    "def cross_entropy_loss(x, y):\n",
    "    # Convert labels to int64 for tf cross-entropy function.\n",
    "    y = tf.cast(y, tf.int64)\n",
    "    # Apply softmax to logits and compute cross-entropy.\n",
    "    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=x)\n",
    "    # Average loss across the batch.\n",
    "    return tf.reduce_mean(loss)\n",
    "\n",
    "# Accuracy metric.\n",
    "@tf.function\n",
    "def accuracy(y_pred, y_true):\n",
    "    # Predicted class is the index of highest score in prediction vector (i.e. argmax).\n",
    "    correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.cast(y_true, tf.int64))\n",
    "    return tf.reduce_mean(tf.cast(correct_prediction, tf.float32), axis=-1)\n",
    "\n",
    "\n",
    "@tf.function\n",
    "def backprop(batch_x, batch_y, trainable_variables):\n",
    "    # Wrap computation inside a GradientTape for automatic differentiation.\n",
    "    with tf.GradientTape() as g:\n",
    "        # Forward pass.\n",
    "        pred = conv_net(batch_x, is_training=True)\n",
    "        # Compute loss.\n",
    "        loss = cross_entropy_loss(pred, batch_y)\n",
    "    # Compute gradients outside the tape context, so that the gradient\n",
    "    # computation itself is not recorded by the tape.\n",
    "    gradients = g.gradient(loss, trainable_variables)\n",
    "    return gradients\n",
    "\n",
    "# Build the function to average the gradients.\n",
    "@tf.function\n",
    "def average_gradients(tower_grads):\n",
    "    avg_grads = []\n",
    "    for tgrads in zip(*tower_grads):\n",
    "        grads = []\n",
    "        for g in tgrads:\n",
    "            expanded_g = tf.expand_dims(g, 0)\n",
    "            grads.append(expanded_g)\n",
    "        \n",
    "        grad = tf.concat(axis=0, values=grads)\n",
    "        grad = tf.reduce_mean(grad, 0)\n",
    "        \n",
    "        avg_grads.append(grad)\n",
    "        \n",
    "    return avg_grads"
   ]
  },
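  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of what `backprop` and `average_gradients` compute together, one synchronous step can be written as an equation. The symbols are notation introduced here for illustration, not part of the original example: $N$ is the number of GPUs, $B_i$ the sub-batch sent to GPU $i$, $L$ the mean cross-entropy loss, and $w_t$ the shared weights at step $t$.\n",
    "\n",
    "$$g_i = \\nabla_w L(w_t; B_i), \\qquad \\bar{g} = \\frac{1}{N} \\sum_{i=1}^{N} g_i$$\n",
    "\n",
    "The optimizer then applies its update rule to the averaged gradient $\\bar{g}$. For plain gradient descent with learning rate $\\eta$, that step would be\n",
    "\n",
    "$$w_{t+1} = w_t - \\eta \\bar{g}$$\n",
    "\n",
    "This notebook uses Adam instead, which rescales $\\bar{g}$ with running moment estimates before updating $w_t$. Since every sub-batch has the same size and $L$ is a mean over examples, $\\bar{g}$ matches the gradient of the mean loss over the full batch, up to the randomness of dropout."
   ]
  },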
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tf.device('/cpu:0'):\n",
    "    # Build convnet.\n",
    "    conv_net = ConvNet()\n",
    "    # Adam optimizer.\n",
    "    optimizer = tf.optimizers.Adam(learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimization process.\n",
    "def run_optimization(x, y):\n",
    "    # Save gradients for all GPUs.\n",
    "    tower_grads = []\n",
    "    # Variables to update, i.e. trainable variables.\n",
    "    trainable_variables = conv_net.trainable_variables\n",
    "\n",
    "    with tf.device('/cpu:0'):\n",
    "        for i in range(num_gpus):\n",
    "            # Split data between GPUs.\n",
    "            gpu_batch_size = batch_size // num_gpus\n",
    "            batch_x = x[i * gpu_batch_size: (i+1) * gpu_batch_size]\n",
    "            batch_y = y[i * gpu_batch_size: (i+1) * gpu_batch_size]\n",
    "\n",
    "            # Build the neural net on each GPU.\n",
    "            with tf.device('/gpu:%i' % i):\n",
    "                grad = backprop(batch_x, batch_y, trainable_variables)\n",
    "                tower_grads.append(grad)\n",
    "\n",
    "                # On the last GPU, average the gradients from all GPUs.\n",
    "                if i == num_gpus - 1:\n",
    "                    gradients = average_gradients(tower_grads)\n",
    "\n",
    "        # Update vars following gradients.\n",
    "        optimizer.apply_gradients(zip(gradients, trainable_variables))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 1, loss: 2.302630, accuracy: 0.101318, speed: 16342.138481 examples/sec\n",
      "step: 20, loss: 2.296755, accuracy: 0.108398, speed: 5355.197204 examples/sec\n",
      "step: 40, loss: 2.216037, accuracy: 0.299072, speed: 12388.080848 examples/sec\n",
      "step: 60, loss: 2.189814, accuracy: 0.362305, speed: 12033.404638 examples/sec\n",
      "step: 80, loss: 2.137831, accuracy: 0.410156, speed: 12189.852065 examples/sec\n",
      "step: 100, loss: 2.102876, accuracy: 0.437744, speed: 12212.349483 examples/sec\n",
      "step: 120, loss: 2.077521, accuracy: 0.460693, speed: 12160.290400 examples/sec\n",
      "step: 140, loss: 2.006775, accuracy: 0.545166, speed: 12202.175380 examples/sec\n",
      "step: 160, loss: 1.994143, accuracy: 0.554443, speed: 12168.070368 examples/sec\n",
      "step: 180, loss: 1.964281, accuracy: 0.597412, speed: 12244.148312 examples/sec\n",
      "step: 200, loss: 1.893395, accuracy: 0.658203, speed: 12197.382402 examples/sec\n",
      "step: 220, loss: 1.880256, accuracy: 0.672363, speed: 12178.323620 examples/sec\n",
      "step: 240, loss: 1.868853, accuracy: 0.676025, speed: 12224.851444 examples/sec\n",
      "step: 260, loss: 1.837151, accuracy: 0.705322, speed: 12101.154436 examples/sec\n",
      "step: 280, loss: 1.799418, accuracy: 0.736816, speed: 12185.701420 examples/sec\n",
      "step: 300, loss: 1.790719, accuracy: 0.755615, speed: 12126.826668 examples/sec\n",
      "step: 320, loss: 1.732242, accuracy: 0.807861, speed: 12229.926783 examples/sec\n",
      "step: 340, loss: 1.732089, accuracy: 0.806885, speed: 12167.651100 examples/sec\n",
      "step: 360, loss: 1.693968, accuracy: 0.835693, speed: 12060.687471 examples/sec\n",
      "step: 380, loss: 1.665804, accuracy: 0.862305, speed: 12130.389108 examples/sec\n",
      "step: 400, loss: 1.627162, accuracy: 0.890381, speed: 12152.946766 examples/sec\n",
      "step: 420, loss: 1.594189, accuracy: 0.920654, speed: 12057.401941 examples/sec\n",
      "step: 440, loss: 1.575212, accuracy: 0.929688, speed: 12196.589206 examples/sec\n",
      "step: 460, loss: 1.569351, accuracy: 0.942383, speed: 12147.345871 examples/sec\n",
      "step: 480, loss: 1.520648, accuracy: 0.974609, speed: 11998.473978 examples/sec\n",
      "step: 500, loss: 1.507439, accuracy: 0.982666, speed: 12152.490287 examples/sec\n",
      "step: 520, loss: 1.495090, accuracy: 0.989746, speed: 12071.718912 examples/sec\n",
      "step: 540, loss: 1.490940, accuracy: 0.989502, speed: 12049.224039 examples/sec\n",
      "step: 560, loss: 1.476727, accuracy: 0.996338, speed: 12134.827424 examples/sec\n",
      "step: 580, loss: 1.475038, accuracy: 0.995850, speed: 12128.228532 examples/sec\n",
      "step: 600, loss: 1.469776, accuracy: 0.997559, speed: 12113.386949 examples/sec\n",
      "step: 620, loss: 1.466832, accuracy: 0.999756, speed: 11939.016031 examples/sec\n",
      "step: 640, loss: 1.466991, accuracy: 0.999023, speed: 12095.815773 examples/sec\n",
      "step: 660, loss: 1.466177, accuracy: 0.999023, speed: 12035.037908 examples/sec\n",
      "step: 680, loss: 1.465074, accuracy: 0.999512, speed: 11789.118097 examples/sec\n",
      "step: 700, loss: 1.464655, accuracy: 0.999512, speed: 11965.087437 examples/sec\n",
      "step: 720, loss: 1.465109, accuracy: 0.999512, speed: 11855.853520 examples/sec\n",
      "step: 740, loss: 1.465021, accuracy: 0.999023, speed: 11774.901096 examples/sec\n",
      "step: 760, loss: 1.463057, accuracy: 1.000000, speed: 11930.138289 examples/sec\n",
      "step: 780, loss: 1.462609, accuracy: 1.000000, speed: 11766.752011 examples/sec\n",
      "step: 800, loss: 1.462320, accuracy: 0.999756, speed: 11744.213314 examples/sec\n",
      "step: 820, loss: 1.462975, accuracy: 1.000000, speed: 11700.815885 examples/sec\n",
      "step: 840, loss: 1.462328, accuracy: 1.000000, speed: 11759.141371 examples/sec\n",
      "step: 860, loss: 1.462561, accuracy: 1.000000, speed: 11650.397252 examples/sec\n",
      "step: 880, loss: 1.462608, accuracy: 0.999512, speed: 11581.170575 examples/sec\n",
      "step: 900, loss: 1.462178, accuracy: 0.999756, speed: 11562.545711 examples/sec\n",
      "step: 920, loss: 1.461582, accuracy: 1.000000, speed: 11616.172231 examples/sec\n",
      "step: 940, loss: 1.462402, accuracy: 1.000000, speed: 11709.561795 examples/sec\n",
      "step: 960, loss: 1.462436, accuracy: 1.000000, speed: 11629.547741 examples/sec\n",
      "step: 980, loss: 1.462415, accuracy: 1.000000, speed: 11623.658645 examples/sec\n",
      "step: 1000, loss: 1.461925, accuracy: 1.000000, speed: 11579.716701 examples/sec\n"
     ]
    }
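   ],
   "source": [
    "# Run training for the given number of steps.\n",
    "ts = time.time()\n",
    "last_step = 0\n",
    "for step, (batch_x, batch_y) in enumerate(train_data.take(training_steps), 1):\n",
    "    # Run the optimization to update W and b values.\n",
    "    run_optimization(batch_x, batch_y)\n",
    "\n",
    "    if step % display_step == 0 or step == 1:\n",
    "        dt = time.time() - ts\n",
    "        # Throughput over the steps run since the last report.\n",
    "        speed = batch_size * (step - last_step) / dt\n",
    "        # Note: with is_training=False the model returns softmax probabilities, and\n",
    "        # cross_entropy_loss applies softmax again, so the reported loss cannot drop\n",
    "        # below about 1.4612 (for 10 classes) even when every prediction is correct.\n",
    "        pred = conv_net(batch_x)\n",
    "        loss = cross_entropy_loss(pred, batch_y)\n",
    "        acc = accuracy(pred, batch_y)\n",
    "        print(\"step: %i, loss: %f, accuracy: %f, speed: %f examples/sec\" % (step, loss, acc, speed))\n",
    "        ts = time.time()\n",
    "        last_step = step"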
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
